import jax
import equinox as eqx
import equinox.nn as nn
import jax.numpy as jnp
import typing as tp

Implementing GPT2 in JAX for fun 🦀🦀🦀
1 GPT2 for JAX 🚀
Explore the full project on the GitHub repository.
1.1 Context ✍️
This project involves rewriting XTTS in JAX to better understand its architecture and functionality. Originally developed by the now-defunct Coqai company, XTTS is a Text-to-Speech model. We’ll recreate its generative component using a GPT2 architecture—a decoder-only transformer—based on (Radford et al. 2019). The implementation closely follows this tutorial.
1.2 GPT2 in Text-to-Speech
1.2.1 What are we building?
Our goal is to generate sequences of tokens for audio synthesis. Specifically, we aim to produce “audio tokens,” small units of audio, discovered using a VQVAE. By learning to map text tokens to audio tokens, the model becomes multi-modal.
The final output sequences represent speech, which we convert into audio using HiFiGAN. Additionally, we enhance speech expressiveness (e.g., tone, speed) by feeding 1024-dimensional vectors representing the target speaker’s paralinguistic features.
1.2.2 Under the Hood
Masked Attention
Masked attention is the core mechanism for learning relationships between tokens. It determines which tokens influence others by projecting them into smaller dimensions and computing relationships. Masking ensures the model focuses only on prior tokens, preventing it from “seeing” future ones.
Studies classify attention patterns into:
1. Semantic: Tokens linked by meaning.
2. Linguistic: Tokens connected by grammar (e.g., verbs and nouns).
3. Rare Tokens: Infrequent but critical tokens.
Feedforward Layers
Feedforward layers mix outputs, add non-linearity via activation functions, and stack layers for hierarchical abstractions. The final output approximates a one-hot encoding in the token vocabulary, enabling token selection for sequential generation.
1.3 Goal 🎯
Implement a GPT2 architecture using Equinox and train it on TinyStories.
2 Model
We have a few things to implement from the ground up. The custom activation function, the forward layer, the masked attention. We then package this up in a nice layer that we can stack, and finally wrap all these stacks into a GPT2 !
We can start by importing our favorite libraries 🥰
2.1 Configuration file
Because of the size of our model, we’re going to be passing down lots of arguments. To avoid having a long unreadable list of parameters we can define a “dataclass” that will allow us to simply pass a config down to the model.
Feel free to experiment with various settings !
from dataclasses import dataclass
@dataclass
class GPTConfig:
block_size: int = 128
vocab_size: int = (
50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
)
n_layer: int = 12
n_head: int = 6
n_embd: int = 200
dropout: float = 0.0
bias: bool = False #2.2 SwiGLU Activation Function
We start by implementing the SwiGLU activation function, introduced in (Shazeer 2020), a powerful variant of GLU.
2.2.1 Why SwiGLU?
SwiGLU dynamically adjusts its activation based on the input. Think of it like a railway switch—redirecting the “activation path” when the input carries different information. This gives the network greater flexibility and control, leading to better performance.
For more details, see this explanation by Boudefel.
Below is a visualization of the Swish function, \(x \times \text{sigmoid}(x)\), which plays a role in SwiGLU:
class SwiGLU(eqx.Module):
    """SwiGLU activation (Shazeer 2020): swish((xW + b) * (xV + c)).

    Parameters
    ----------
    input_dim : size of the incoming feature vector.
    output_dim : size of the produced feature vector.
    key : PRNG key used to initialise W, V, b and c.
    """

    W: nn.Linear
    V: nn.Linear
    b: jax.Array
    c: jax.Array

    def __init__(self, input_dim, output_dim, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.W = nn.Linear(input_dim, output_dim, key=key1)
        self.V = nn.Linear(input_dim, output_dim, key=key2)
        # Fix: `(output_dim)` is just an int in parentheses, not a shape tuple —
        # jax.random.normal expects a sequence of ints, so use `(output_dim,)`.
        self.b = jax.random.normal(key3, (output_dim,))
        self.c = jax.random.normal(key4, (output_dim,))

    def __call__(self, x):
        # Gated activation: one branch modulates ("switches") the other.
        return jax.nn.swish((self.W(x) + self.b) * (self.V(x) + self.c))

Code
key = jax.random.PRNGKey(69)
mod = SwiGLU(10, 4, key)
x = jnp.ones(10)
print(mod(x).shape)2.3 MLP
We can now move onto the multilayer perceptron, which we mentionned earlier as the feedforward part of our network. Because the model is big and we want to make sure that it doesn’t just “memorize” things, we include dropout which pushes the model to avoid relying on singular neurons / data flowing through for information.
✨ You’ll also notice that since our SwiGLU has two linear layers in it, in reality each MLP that we’ll use uses 4 layers !!
class MLP(eqx.Module):
layers: list
def __init__(self, config, key):
key1, key2, key3 = jax.random.split(key, 3)
self.ff1 = nn.Linear(
config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=key1
)
self.act = SwiGLU(4 * config.n_embd, 4 * config.n_embd, key=key2)
self.ff2 = nn.Linear(
4 * config.n_embd, config.n_embd, use_bias=config.bias, key=key3
)
self.drop = nn.Dropout(config.dropout)
def __call__(self, x):
y = self.ff1(x)
y = self.act(y)
y = self.ff2(y)
return self.drop(y)Again, we can compare with their implementation to make sure we’re close enough.
Code
class MLPTheirs(eqx.Module):
    """Reference MLP (nanoJAXGPT style) used to sanity-check the MLP above.

    Same shape as MLP: expand to 4 * n_embd, SwiGLU, project back, dropout.
    Unlike MLP, its __call__ vmaps internally, so it takes a (seq, n_embd) batch.
    """

    c_fc: eqx.nn.Linear
    swiglu: SwiGLU
    c_proj: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, config, key):
        lkey1, lkey2, skey = jax.random.split(key, 3)
        self.c_fc = eqx.nn.Linear(
            config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=lkey1
        )
        self.swiglu = SwiGLU(4 * config.n_embd, 4 * config.n_embd, skey)
        self.c_proj = eqx.nn.Linear(
            4 * config.n_embd, config.n_embd, use_bias=config.bias, key=lkey2
        )
        self.dropout = eqx.nn.Dropout(config.dropout)

    def __call__(self, x):
        # vmap maps each per-token Linear over the sequence axis.
        x = jax.vmap(self.c_fc)(x)
        x = jax.vmap(self.swiglu)(x)
        x = jax.vmap(self.c_proj)(x)
        x = self.dropout(x)
        return x

Code
config = GPTConfig()
key = jax.random.PRNGKey(69)
mlp = MLP(config, key)
mlp_theirs = MLPTheirs(config, key)
x = jax.random.normal(key, (100, config.n_embd))
res = jax.vmap(mlp)(x)
res_theirs = mlp_theirs(x)
average_diff = jnp.mean(res_theirs)
print(average_diff)2.4 Masked attention
Moving onto one of the more complicated aspects of the model, but in the end it simply learns to output which tokens are more important with each other. There are plenty of fantastic tutorials out there for better understanding the underlying concept, notably : Transformers explained visually
import math
class CausalSelfAttention(eqx.Module):
attnk: nn.Linear
attnq: nn.Linear
attnv: nn.Linear
proj: nn.Linear
resid_dropout: nn.Dropout
attn_dropout: nn.Dropout
mask: jax.Array = eqx.field(static=True)
def __init__(self, config, key):
key1, key2, key3, key4 = jax.random.split(key, 4)
self.attnk = nn.Linear(
config.n_embd, config.n_embd, use_bias=config.bias, key=key1
)
self.attnv = nn.Linear(
config.n_embd, config.n_embd, use_bias=config.bias, key=key2
)
self.attnq = nn.Linear(
config.n_embd, config.n_embd, use_bias=config.bias, key=key3
)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.proj = nn.Linear(
config.n_embd, config.n_embd, use_bias=config.bias, key=key4
)
self.mask = jnp.tril(jnp.ones((config.block_size, config.block_size)))
# Could play arround with the different attention score calculations (Baidhu ?)
# X is an embedding, it should self attend.
def __call__(self, x):
# x = jnp.swapaxes(x, -1, -2)
T, C = x.shape # Seq length and embedding dim.
q = jax.vmap(self.attnq)(x)
k = jax.vmap(self.attnk)(x)
v = jax.vmap(self.attnv)(x)
att = jnp.matmul(q, jnp.transpose(k)) / math.sqrt(jnp.shape(k)[-1])
att = jnp.where(
jax.numpy.equal(jax.lax.stop_gradient(self.mask[:T, :T]), 0),
float("-inf"),
att,
)
att = jax.nn.softmax(att, axis=-1)
att = self.attn_dropout(att)
y = jnp.matmul(att, v)
y = jax.vmap(self.proj)(y)
y = self.resid_dropout(y)
return ySmall check…
Code
config = GPTConfig()
key = jax.random.PRNGKey(69)
mlp = CausalSelfAttention(config, key)
print(mlp(jax.random.normal(key, (100, config.n_embd))).shape)2.5 Block
Ok ! Now that we have the component parts of what we call a “block” we can assemble them. This will then be stacked to get as many layers of abstraction as we wish. In our case we will stack it 12 times as per the GPTConfig we defined.
class Block(eqx.Module):
norm: nn.LayerNorm
attn: CausalSelfAttention
mlp: MLP
def __init__(self, config, key):
key1, key2 = jax.random.split(key, 2)
self.norm = nn.LayerNorm(config.n_embd, use_bias=config.bias)
self.attn = CausalSelfAttention(config, key=key1)
self.mlp = MLP(config, key=key2)
def __call__(self, x):
y = jax.vmap(self.norm)(x)
y = self.attn(
y
) # Can't vmap as the whole point is exchange info between tokens.
x = y + x
y = jax.vmap(self.norm)(x)
y = jax.vmap(self.mlp)(y)
x = y + x
return xCan compare with their work.
Code
class BlockTheirs(eqx.Module):
    """Reference transformer block (nanoJAXGPT style) with two independent LayerNorms."""

    ln_1: eqx.nn.LayerNorm
    attn: CausalSelfAttention
    ln_2: eqx.nn.LayerNorm
    mlp: MLP

    def __init__(self, config, key):
        ckey, mkey = jax.random.split(key, 2)
        self.ln_1 = eqx.nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.attn = CausalSelfAttention(config, ckey)
        self.ln_2 = eqx.nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.mlp = MLPTheirs(config, mkey)

    def __call__(self, x):
        # Pre-norm residual wiring: x + sublayer(norm(x)), norms vmapped per token.
        x = x + self.attn(jax.vmap(self.ln_1)(x))
        x = x + self.mlp(jax.vmap(self.ln_2)(x))
        return x

Code
config = GPTConfig()
key = jax.random.PRNGKey(69)
mlp = Block(config, key)
mlp_theirs = BlockTheirs(config, key)
x = jax.random.normal(key, (100, config.n_embd))
res = mlp(x)
res_their = mlp_theirs(x)
average_diff = jnp.mean(res - res_their)
print(average_diff)We can finally add the embeddings to our model, which are the maps that send tokens to the dimension that the model works with, i.e. 1024 dims.
class GPT(eqx.Module):
wte: nn.Embedding # Token embeddings
wpe: nn.Embedding # Positional embeddings
drop: nn.Dropout
layers: list
norm: nn.LayerNorm
def __init__(self, config, key):
key1, key2, key3, key4 = jax.random.split(key, 4)
self.wte = nn.Embedding(config.vocab_size, config.n_embd, key=key1)
self.wpe = nn.Embedding(config.block_size, config.n_embd, key=key2)
self.drop = nn.Dropout(config.dropout)
self.layers = [Block(config, key) for _ in range(config.n_layer)]
self.norm = nn.LayerNorm(config.n_embd, use_bias=config.bias)
def __call__(self, token_ids):
(t,) = token_ids.shape
# Should use better positional embeddings with cos and sin.
pos = jnp.arange(0, t, dtype=jnp.int64)
tok_emb = jax.vmap(self.wte)(token_ids)
pos_emb = jax.vmap(self.wpe)(pos)
# Dropout at the first layer ? Seems a bit aggressive...
x = self.drop(tok_emb + pos_emb)
for block in self.layers:
x = block(x)
x = jax.vmap(self.norm)(x)
return xComparing with their work…
Code
class TransformerLayerTheirs(eqx.Module):
    """Reference transformer trunk (nanoJAXGPT style): embeddings -> blocks -> final norm."""

    # NOTE(review): static fields must be hashable under jax.jit; GPTConfig is a
    # plain (non-frozen) dataclass with eq=True, which is unhashable — confirm this
    # module never crosses a jit boundary as-is.
    _config: GPTConfig = eqx.field(static=True)
    wte: eqx.nn.Embedding
    wpe: eqx.nn.Embedding
    drop: eqx.nn.Dropout
    h: list
    ln_f: eqx.nn.LayerNorm

    def __init__(self, config, key):
        ekey, pkey, hkey, fkey = jax.random.split(key, 4)
        assert config.vocab_size is not None
        assert config.block_size is not None
        self._config = config
        self.wte = eqx.nn.Embedding(config.vocab_size, config.n_embd, key=ekey)
        self.wpe = eqx.nn.Embedding(config.block_size, config.n_embd, key=pkey)
        self.drop = eqx.nn.Dropout(config.dropout)
        self.h = [Block(config, hkey) for _ in range(config.n_layer)]
        self.ln_f = eqx.nn.LayerNorm(config.n_embd, use_bias=config.bias)
        # Fix: removed a stray `self.lm_head = nn.Linear(...)` that was copy-pasted
        # here — `lm_head` was never declared as a field (eqx.Module forbids
        # undeclared assignments) and it referenced an undefined `key2`. The LM
        # head belongs to the GPT wrapper below, not to the trunk.

    def __call__(self, idx):
        (t,) = idx.shape
        assert (
            t <= self._config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self._config.block_size}"
        # int32: jnp.int64 needs jax x64 mode; default jax downcasts with a warning.
        pos = jnp.arange(0, t, dtype=jnp.int32)
        tok_emb = jax.vmap(self.wte)(idx)  # token embeddings of shape (t, n_embd)
        pos_emb = jax.vmap(self.wpe)(pos)  # position embeddings of shape (t, n_embd)
        x = self.drop(tok_emb + pos_emb)
        for block in self.h:
            x = block(x)
        x = jax.vmap(self.ln_f)(x)
        return x

Code
config = GPTConfig()
key = jax.random.PRNGKey(69)
mlp = TransformerLayer(config, key)
mlp_theirs = TransformerLayerTheirs(config, key)
x = jax.random.normal(key, (100))
x = np.array([0, 0, 1, 2, 3, 4, 5, 6])
res = mlp(x)
res_their = mlp_theirs(x)
average_diff = np.mean(res - res_their)
print(average_diff)We’ll then
class GPT(eqx.Module):
transformer: TransformerLayer
lm_head: nn.Linear
def __init__(self, config, key):
key1, key2 = jax.random.split(key, 2)
self.transformer = TransformerLayer(config, key1)
self.lm_head = nn.Linear(
config.n_embd, config.vocab_size, use_bias=False, key=key2
)
def __call__(self, token_ids):
y = self.transformer(token_ids)
logits = jax.vmap(self.lm_head)(y)
return logitsWe can compare our method with the one implemented in nanoJAXGPT:
Code
class CausalSelfAttentionTheirs(eqx.Module):
    """Reference multi-head causal self-attention (nanoJAXGPT port of nanoGPT)."""

    c_attn: eqx.nn.Linear
    c_proj: eqx.nn.Linear
    attn_dropout: eqx.nn.Dropout
    resid_dropout: eqx.nn.Dropout
    # NOTE(review): static fields must be hashable under jax.jit and jnp arrays are
    # not — confirm this module never crosses a jit boundary with this field set.
    bias: jax.Array = eqx.field(static=True)
    _config: GPTConfig = eqx.field(static=True)

    def __init__(self, config, key):
        assert config.n_embd % config.n_head == 0
        # PRNGKey
        lkey1, lkey2 = jax.random.split(key, 2)
        # key, query, value projections for all heads, but in a batch
        self.c_attn = eqx.nn.Linear(
            config.n_embd, 3 * config.n_embd, use_bias=config.bias, key=lkey1
        )
        # output projection
        self.c_proj = eqx.nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=lkey2
        )
        # regularization
        self.attn_dropout = eqx.nn.Dropout(config.dropout)
        self.resid_dropout = eqx.nn.Dropout(config.dropout)
        self._config = config
        # causal mask to ensure that attention is only applied to the left in the input sequence
        # Has been made a buffer by using lax.stop_gradient whenever it is used.
        # Immutability calls for reshape, plus there is no view for jnp (or numpy) arrays.
        self.bias = jnp.tril(jnp.ones((config.block_size, config.block_size))).reshape(
            1, 1, config.block_size, config.block_size
        )

    def __call__(self, x):
        T, C = jnp.shape(x)  # sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = jnp.split(jax.vmap(self.c_attn)(x), 3, axis=1)
        # Immutability calls for reshape, plus there is no view for jnp (or numpy) arrays.
        k = jnp.swapaxes(
            k.reshape(T, self._config.n_head, C // self._config.n_head), 0, 1
        )  # (nh, T, hs)
        q = jnp.swapaxes(
            q.reshape(T, self._config.n_head, C // self._config.n_head), 0, 1
        )  # (nh, T, hs)
        v = jnp.swapaxes(
            v.reshape(T, self._config.n_head, C // self._config.n_head), 0, 1
        )  # (nh, T, hs)
        # manual implementation of attention
        att = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(jnp.shape(k)[-1])
        # Note: Added the stop_gradient just to be safe, I see no update rule acting on the bias inside the
        # forward pass.
        att = jnp.where(
            jax.lax.stop_gradient(self.bias[:, :, :T, :T]) == 0, float("-inf"), att
        )
        att = jax.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att)
        y = jnp.matmul(att, v)  # (nh, T, T) x (nh, T, hs) -> (nh, T, hs)
        # Reshaping with Immutability creates a new copy
        # NOTE(review): `y` is 3-D (nh, T, hs) here, so swapaxes(1, 2) swaps T and hs
        # (yielding (nh, hs, T)) — the batched nanoGPT original transposes head and
        # sequence axes, which for this layout would be swapaxes(0, 1); confirm.
        y = jnp.swapaxes(y, 1, 2).reshape(
            T, C
        )  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(jax.vmap(self.c_proj)(y))
        return y

Code
key = jax.random.PRNGKey(69)
key1, key2, key3 = jax.random.split(key, 3)
gpt_config = GPTConfig()
ours = CausalSelfAttention(gpt_config, key1)
# theirs = CausalSelfAttentionTheirs(gpt_config, key2)
# Equinox's own single-head MultiheadAttention as an independent reference.
eq = nn.MultiheadAttention(
    1,
    # Fix: this cell defined `gpt_config` but then read the stale global `config`
    # left over from earlier cells — use gpt_config consistently.
    query_size=gpt_config.n_embd,
    value_size=gpt_config.n_embd,
    key_size=gpt_config.n_embd,
    output_size=gpt_config.n_embd,
    key=key1,
)
# Pass sequence and hidden dim
x = jax.random.normal(key3, (100, gpt_config.n_embd))
y_ours = ours(x)
y_theirs = eq(x, x, x)
average_diff = jnp.mean(y_ours - y_theirs)
average_std = jnp.std(y_ours - y_theirs)
# print(y_ours[0][3])
print(average_diff)
print(average_std)
# print(y_theirs[0][3])

Values seem to be coming out the same; it's interesting to see this distribution, which is biased in some sense - TODO check why this is the case
import matplotlib.pyplot as plt

# Overlay the two attention outputs, flattened to 1-D, to eyeball the match.
# Fix: `np` was never imported — use jnp.reshape for both lines.
fig, axs = plt.subplots(1)
axs.plot(jnp.reshape(y_ours, (-1,)))
axs.plot(jnp.reshape(y_theirs, (-1,)))
plt.show()